{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0be845da",
   "metadata": {},
   "source": [
    "## Set-up"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "33681dd1",
   "metadata": {},
   "source": [
    "Necessary imports and helper functions for displaying points, boxes, and masks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "# Project root; override with the TRANS_WSSS_ROOT environment variable on other machines.\n",
    "PROJECT_ROOT = os.environ.get(\"TRANS_WSSS_ROOT\", \"/home/xiaobo/Project/Trans-WSSS\")\n",
    "# Relative dataset/checkpoint paths used below resolve from the project root.\n",
    "os.chdir(PROJECT_ROOT)\n",
    "print(os.getcwd())\n",
    "\n",
    "# Make the vendored segment_anything package importable; guarded so re-running the cell is idempotent.\n",
    "SAM_REPO_DIR = \"others/segment_anything\"\n",
    "if SAM_REPO_DIR not in sys.path:\n",
    "    sys.path.append(SAM_REPO_DIR)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69b28288",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "import cv2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29bc90d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_mask(mask, ax, random_color=False):\n",
    "    \"\"\"Overlay a binary mask on `ax` as a semi-transparent RGBA image.\n",
    "\n",
    "    The last two axes of `mask` are taken as (H, W). The tint is the default blue unless\n",
    "    `random_color` is set, in which case a random RGB colour (alpha 0.6) is used.\n",
    "    \"\"\"\n",
    "    rgba = (\n",
    "        np.concatenate([np.random.random(3), np.array([0.6])], axis=0)\n",
    "        if random_color\n",
    "        else np.array([30/255, 144/255, 255/255, 0.6])\n",
    "    )\n",
    "    height, width = mask.shape[-2:]\n",
    "    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)\n",
    "    ax.imshow(overlay)\n",
    "\n",
    "\n",
    "def show_points(coords, labels, ax, marker_size=375):\n",
    "    \"\"\"Scatter (x, y) prompt points on `ax`: label 1 as green stars, label 0 as red stars.\"\"\"\n",
    "    marker_style = dict(marker='*', s=marker_size, edgecolor='white', linewidth=1.25)\n",
    "    for label_value, colour in ((1, 'green'), (0, 'red')):\n",
    "        chosen = coords[labels == label_value]\n",
    "        ax.scatter(chosen[:, 0], chosen[:, 1], color=colour, **marker_style)\n",
    "\n",
    "\n",
    "def show_box(box, ax):\n",
    "    \"\"\"Outline an (x0, y0, x1, y1) box on `ax` in green with a transparent face.\"\"\"\n",
    "    x0, y0, x1, y1 = box[0], box[1], box[2], box[3]\n",
    "    ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "23842fb2",
   "metadata": {},
   "source": [
    "## Example image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c2e4f6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example VOC2012 image (the upstream SAM demo used notebooks/images/truck.jpg instead).\n",
    "IMAGE_PATH = 'datasets/VOC2012/JPEGImages/2007_000480.jpg'\n",
    "image_bgr = cv2.imread(IMAGE_PATH)\n",
    "# cv2.imread returns None (instead of raising) when the file is missing or unreadable.\n",
    "if image_bgr is None:\n",
    "    raise FileNotFoundError(f\"could not read image: {IMAGE_PATH}\")\n",
    "image = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e30125fd",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(figsize=(10, 10))\n",
    "ax.imshow(image)\n",
    "ax.axis('on')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "98b228b8",
   "metadata": {},
   "source": [
    "## Selecting objects with SAM"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0bb1927b",
   "metadata": {},
   "source": [
    "First, load the SAM model and predictor. Change the path below to point to the SAM checkpoint. Running on CUDA and using the default model are recommended for best results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e28150b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "# NOTE(review): \"..\" is inherited from the upstream SAM demo notebook; after changing into the\n",
    "# project root it points outside the project and is probably unnecessary -- confirm and drop.\n",
    "sys.path.append(\"..\")\n",
    "from segment_anything import sam_model_registry, SamPredictor\n",
    "\n",
    "sam_checkpoint = \"pretrains/SAM/sam_vit_h_4b8939.pth\"\n",
    "model_type = \"vit_h\"\n",
    "\n",
    "# Fall back to CPU so the notebook still runs (slowly) on machines without a GPU.\n",
    "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)\n",
    "sam.to(device=device)\n",
    "\n",
    "predictor = SamPredictor(sam)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c925e829",
   "metadata": {},
   "source": [
    "Process the image to produce an image embedding by calling `SamPredictor.set_image`. `SamPredictor` remembers this embedding and will use it for subsequent mask prediction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d95d48dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "predictor.set_image(image)\n",
    "# Low-res mask logits cover the padded square encoder input at 1/4 resolution (256 for the\n",
    "# 1024-pixel encoder); rows/columns beyond the resized image's extent are padding.\n",
    "low_res_size = predictor.model.image_encoder.img_size // 4\n",
    "low_res_pad_h = low_res_size - predictor.input_size[0] // 4\n",
    "low_res_pad_w = low_res_size - predictor.input_size[1] // 4\n",
    "print(f\"low_res_pad_h: {low_res_pad_h}, low_res_pad_w: {low_res_pad_w}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d8fc7a46",
   "metadata": {},
   "source": [
    "To select an object, choose a point on it. Points are input to the model in (x,y) format and come with labels 1 (foreground point) or 0 (background point). Multiple points can be input; here we use only one. The chosen point will be shown as a star on the image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c69570c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One foreground click in (x, y) pixel coordinates; label 1 = foreground.\n",
    "input_point = np.array([(48, 226)])\n",
    "input_label = np.ones(1, dtype=int)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a91ba973",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(figsize=(10, 10))\n",
    "ax.imshow(image)\n",
    "show_points(input_point, input_label, ax)\n",
    "ax.axis('on')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c765e952",
   "metadata": {},
   "source": [
    "Predict with `SamPredictor.predict`. The model returns masks, quality predictions for those masks, and low resolution mask logits that can be passed to the next iteration of prediction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5373fd68",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Single ambiguous click -> three candidate masks, their predicted quality scores,\n",
    "# and the low-resolution logits that can be fed back as a mask prompt.\n",
    "masks, scores, logits = predictor.predict(point_coords=input_point, point_labels=input_label, multimask_output=True)\n",
    "# Separate handle on these logits: later cells overwrite `logits` but still reuse the best one as `mask_input`.\n",
    "multi_logits = logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Same prompt; return_logits=True makes the full-resolution outputs unthresholded logits instead of booleans.\n",
    "mask_logits = predictor.predict(\n",
    "    point_coords=input_point,\n",
    "    point_labels=input_label,\n",
    "    multimask_output=True,\n",
    "    return_logits=True,\n",
    ")[0]"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "id": "c7f0e938",
   "metadata": {},
   "source": [
    "With `multimask_output=True` (the default setting), SAM outputs 3 masks, where `scores` gives the model's own estimation of the quality of these masks. This setting is intended for ambiguous input prompts, and helps the model disambiguate different objects consistent with the prompt. When `False`, it will return a single mask. For ambiguous prompts such as a single point, it is recommended to use `multimask_output=True` even if only a single mask is desired; the best single mask can be chosen by picking the one with the highest score returned in `scores`. This will often result in a better mask."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47821187",
   "metadata": {},
   "outputs": [],
   "source": [
    "# masks: (number_of_masks) x H x W; scores: one per mask; logits: low-res per-mask maps\n",
    "for arr in (masks, scores, logits):\n",
    "    print(arr.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "logit_summary = f\"{logits.mean()=}, {logits.std()=}, {logits.min()=}, {logits.max()=}\"\n",
    "print(logit_summary)\n",
    "logits[0]  # first mask's low-res logit map, shown as the cell output"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Tile the per-mask low-res logit maps left to right: (N, h, w) -> (h, N * w).\n",
    "n_masks, low_h, low_w = logits.shape\n",
    "plt.figure(figsize=(20, 20))\n",
    "plt.imshow(logits.transpose(1, 0, 2).reshape(low_h, n_masks * low_w))\n",
    "plt.colorbar(fraction=0.02)\n",
    "plt.title(\"Low-res mask logits (masks tiled left to right)\")\n",
    "plt.axis('on')\n",
    "plt.show()"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Padding region of the low-res logits: the bottom `low_res_pad_h` rows plus the right\n",
    "# `low_res_pad_w` columns. Built as a boolean mask because slicing with `[-pad:]` selects\n",
    "# the WHOLE axis when pad == 0 (e.g. a square image would report the mean of every pixel).\n",
    "low_h, low_w = logits.shape[-2:]\n",
    "pad_region = np.zeros((low_h, low_w), dtype=bool)\n",
    "pad_region[low_h - low_res_pad_h:, :] = True\n",
    "pad_region[:, low_w - low_res_pad_w:] = True\n",
    "\n",
    "for i, logit in enumerate(logits):\n",
    "    pad_mean = logit[pad_region].mean() if pad_region.any() else \"n/a (no padding)\"\n",
    "    print(f\"mask {i} padding logit mean: {pad_mean}\")\n",
    "\n",
    "for i, logit in enumerate(mask_logits):\n",
    "    fg, bg = masks[i], ~masks[i]\n",
    "    # An empty selection would yield nan plus a RuntimeWarning, so report n/a instead.\n",
    "    fg_mean = logit[fg].mean() if fg.any() else \"n/a\"\n",
    "    bg_mean = logit[bg].mean() if bg.any() else \"n/a\"\n",
    "    print(f\"mask {i} fg logit mean: {fg_mean}\")\n",
    "    print(f\"mask {i} bg logit mean: {bg_mean}\")"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# One log-scale histogram of low-res logit values per mask, laid out in a single row.\n",
    "n_masks = len(logits)\n",
    "plt.figure(figsize=(10 * n_masks, 10))\n",
    "for i, logit in enumerate(logits):\n",
    "    plt.subplot(1, n_masks, i + 1)\n",
    "    plt.hist(logit.flatten(), bins=50)\n",
    "    plt.yscale('log', base=10)\n",
    "    plt.title(f\"mask {i} low-res logits\", fontsize=20)\n",
    "    plt.tick_params(axis='both', which='major', labelsize=20)\n",
    "plt.show()"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "print(mask_logits.shape)  # (number_of_masks) x H x W\n",
    "# Full-resolution logit maps side by side: (N, H, W) -> (H, N * W).\n",
    "tiled_mask_logits = mask_logits.transpose(1, 0, 2).reshape(image.shape[0], -1)\n",
    "plt.figure(figsize=(20, 20))\n",
    "plt.imshow(tiled_mask_logits)\n",
    "plt.colorbar(fraction=0.01)\n",
    "plt.axis('on')\n",
    "plt.show()"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9c227a6",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "plt.figure(figsize=(30,10))\n",
    "for idx, (mask, score) in enumerate(zip(masks, scores), start=1):\n",
    "    ax = plt.subplot(1, 3, idx)\n",
    "    ax.imshow(image)\n",
    "    show_mask(mask, ax)\n",
    "    show_points(input_point, input_label, ax)\n",
    "    ax.set_title(f\"Mask {idx}, Score: {score:.3f}\", fontsize=18)\n",
    "    ax.axis('off')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3fa31f7c",
   "metadata": {},
   "source": [
    "## Specifying a specific object with additional points"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "88d6d29a",
   "metadata": {},
   "source": [
    "The single input point is ambiguous, and the model has returned multiple objects consistent with it. To obtain a single object, multiple points can be provided. If available, a mask from a previous iteration can also be supplied to the model to aid in prediction. When specifying a single object with multiple prompts, a single mask can be requested by setting `multimask_output=False`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6923b94",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two foreground clicks (x, y), both labelled 1.\n",
    "input_point = np.array([[48, 226], [182, 237]])\n",
    "input_label = np.array([1, 1])\n",
    "\n",
    "# Seed the decoder with the low-res logits of the best first-round mask; `scores` still holds\n",
    "# the first (multimask) prediction's scores because later predictions discard theirs.\n",
    "best_idx = np.argmax(scores)\n",
    "mask_input = multi_logits[best_idx, :, :]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d98f96a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Single-mask prediction refined by the previous low-res logits. Its scores are not kept, so\n",
    "# `scores` keeps matching `multi_logits` for later mask_input selection.\n",
    "prediction = predictor.predict(\n",
    "    point_coords=input_point,\n",
    "    point_labels=input_label,\n",
    "    mask_input=mask_input[np.newaxis],\n",
    "    multimask_output=False,\n",
    ")\n",
    "masks, logits = prediction[0], prediction[2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Same prompts; return_logits=True gives unthresholded full-resolution logits instead of booleans.\n",
    "mask_logits = predictor.predict(\n",
    "    point_coords=input_point,\n",
    "    point_labels=input_label,\n",
    "    mask_input=mask_input[np.newaxis],\n",
    "    multimask_output=False,\n",
    "    return_logits=True,\n",
    ")[0]"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ce8b82f",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(masks.shape)  # (number_of_masks) x H x W\n",
    "# `scores` is not printed: the single-mask predict above discarded its scores, so `scores`\n",
    "# would still show the first (multimask) prediction's shape.\n",
    "print(logits.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "logit_summary = f\"{logits.mean()=}, {logits.std()=}, {logits.min()=}, {logits.max()=}\"\n",
    "print(logit_summary)\n",
    "logits[0]  # first mask's low-res logit map, shown as the cell output"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Tile the per-mask low-res logit maps left to right: (N, h, w) -> (h, N * w).\n",
    "n_masks, low_h, low_w = logits.shape\n",
    "plt.figure(figsize=(10, 10))\n",
    "plt.imshow(logits.transpose(1, 0, 2).reshape(low_h, n_masks * low_w))\n",
    "plt.colorbar(fraction=0.1)\n",
    "plt.title(\"Low-res mask logits\")\n",
    "plt.axis('on')\n",
    "plt.show()"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Padding region of the low-res logits: the bottom `low_res_pad_h` rows plus the right\n",
    "# `low_res_pad_w` columns. Built as a boolean mask because slicing with `[-pad:]` selects\n",
    "# the WHOLE axis when pad == 0 (e.g. a square image would report the mean of every pixel).\n",
    "low_h, low_w = logits.shape[-2:]\n",
    "pad_region = np.zeros((low_h, low_w), dtype=bool)\n",
    "pad_region[low_h - low_res_pad_h:, :] = True\n",
    "pad_region[:, low_w - low_res_pad_w:] = True\n",
    "\n",
    "for i, logit in enumerate(logits):\n",
    "    pad_mean = logit[pad_region].mean() if pad_region.any() else \"n/a (no padding)\"\n",
    "    print(f\"mask {i} padding logit mean: {pad_mean}\")\n",
    "\n",
    "for i, logit in enumerate(mask_logits):\n",
    "    fg, bg = masks[i], ~masks[i]\n",
    "    # An empty selection would yield nan plus a RuntimeWarning, so report n/a instead.\n",
    "    fg_mean = logit[fg].mean() if fg.any() else \"n/a\"\n",
    "    bg_mean = logit[bg].mean() if bg.any() else \"n/a\"\n",
    "    print(f\"mask {i} fg logit mean: {fg_mean}\")\n",
    "    print(f\"mask {i} bg logit mean: {bg_mean}\")"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# One log-scale histogram of low-res logit values per mask, laid out in a single row.\n",
    "n_masks = len(logits)\n",
    "plt.figure(figsize=(10 * n_masks, 10))\n",
    "for i, logit in enumerate(logits):\n",
    "    plt.subplot(1, n_masks, i + 1)\n",
    "    plt.hist(logit.flatten(), bins=50)\n",
    "    plt.yscale('log', base=10)\n",
    "    plt.title(f\"mask {i} low-res logits\", fontsize=20)\n",
    "    plt.tick_params(axis='both', which='major', labelsize=20)\n",
    "plt.show()"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "print(mask_logits.shape)  # (number_of_masks) x H x W\n",
    "# Full-resolution logit maps side by side: (N, H, W) -> (H, N * W).\n",
    "tiled_mask_logits = mask_logits.transpose(1, 0, 2).reshape(image.shape[0], -1)\n",
    "plt.figure(figsize=(10, 10))\n",
    "plt.imshow(tiled_mask_logits)\n",
    "plt.colorbar(fraction=0.025)\n",
    "plt.axis('on')\n",
    "plt.show()"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e06d5c8d",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(figsize=(10, 10))\n",
    "ax.imshow(image)\n",
    "show_mask(masks, ax)\n",
    "show_points(input_point, input_label, ax)\n",
    "ax.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c93e2087",
   "metadata": {},
   "source": [
    "To exclude an unwanted region from the mask, a background point (with label 0, here shown in red) can be supplied."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a196f68",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two foreground clicks plus one background click (label 0) to carve a region out of the mask.\n",
    "input_point = np.array([[48, 226], [182, 237], [267, 345]])\n",
    "input_label = np.array([1, 1, 0])\n",
    "\n",
    "# Seed the decoder with the low-res logits of the best first-round mask; `scores` still holds\n",
    "# the first (multimask) prediction's scores because later predictions discard theirs.\n",
    "best_idx = np.argmax(scores)\n",
    "mask_input = multi_logits[best_idx, :, :]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81a52282",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Single-mask prediction refined by the previous low-res logits. Its scores are not kept, so\n",
    "# `scores` keeps matching `multi_logits` for later mask_input selection.\n",
    "prediction = predictor.predict(\n",
    "    point_coords=input_point,\n",
    "    point_labels=input_label,\n",
    "    mask_input=mask_input[np.newaxis],\n",
    "    multimask_output=False,\n",
    ")\n",
    "masks, logits = prediction[0], prediction[2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Same prompts; return_logits=True gives unthresholded full-resolution logits instead of booleans.\n",
    "mask_logits = predictor.predict(\n",
    "    point_coords=input_point,\n",
    "    point_labels=input_label,\n",
    "    mask_input=mask_input[np.newaxis],\n",
    "    multimask_output=False,\n",
    "    return_logits=True,\n",
    ")[0]"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "print(masks.shape)  # (number_of_masks) x H x W\n",
    "# `scores` is not printed: the single-mask predict above discarded its scores, so `scores`\n",
    "# would still show the first (multimask) prediction's shape.\n",
    "print(logits.shape)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "logit_summary = f\"{logits.mean()=}, {logits.std()=}, {logits.min()=}, {logits.max()=}\"\n",
    "print(logit_summary)\n",
    "logits[0]  # first mask's low-res logit map, shown as the cell output"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Tile the per-mask low-res logit maps left to right: (N, h, w) -> (h, N * w).\n",
    "n_masks, low_h, low_w = logits.shape\n",
    "plt.figure(figsize=(10, 10))\n",
    "plt.imshow(logits.transpose(1, 0, 2).reshape(low_h, n_masks * low_w))\n",
    "plt.colorbar(fraction=0.1)\n",
    "plt.title(\"Low-res mask logits\")\n",
    "plt.axis('on')\n",
    "plt.show()"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Padding region of the low-res logits: the bottom `low_res_pad_h` rows plus the right\n",
    "# `low_res_pad_w` columns. Built as a boolean mask because slicing with `[-pad:]` selects\n",
    "# the WHOLE axis when pad == 0 (e.g. a square image would report the mean of every pixel).\n",
    "low_h, low_w = logits.shape[-2:]\n",
    "pad_region = np.zeros((low_h, low_w), dtype=bool)\n",
    "pad_region[low_h - low_res_pad_h:, :] = True\n",
    "pad_region[:, low_w - low_res_pad_w:] = True\n",
    "\n",
    "for i, logit in enumerate(logits):\n",
    "    pad_mean = logit[pad_region].mean() if pad_region.any() else \"n/a (no padding)\"\n",
    "    print(f\"mask {i} padding logit mean: {pad_mean}\")\n",
    "\n",
    "for i, logit in enumerate(mask_logits):\n",
    "    fg, bg = masks[i], ~masks[i]\n",
    "    # An empty selection would yield nan plus a RuntimeWarning, so report n/a instead.\n",
    "    fg_mean = logit[fg].mean() if fg.any() else \"n/a\"\n",
    "    bg_mean = logit[bg].mean() if bg.any() else \"n/a\"\n",
    "    print(f\"mask {i} fg logit mean: {fg_mean}\")\n",
    "    print(f\"mask {i} bg logit mean: {bg_mean}\")"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# One log-scale histogram of low-res logit values per mask, laid out in a single row.\n",
    "n_masks = len(logits)\n",
    "plt.figure(figsize=(10 * n_masks, 10))\n",
    "for i, logit in enumerate(logits):\n",
    "    plt.subplot(1, n_masks, i + 1)\n",
    "    plt.hist(logit.flatten(), bins=50)\n",
    "    plt.yscale('log', base=10)\n",
    "    plt.title(f\"mask {i} low-res logits\", fontsize=20)\n",
    "    plt.tick_params(axis='both', which='major', labelsize=20)\n",
    "plt.show()"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "print(mask_logits.shape)  # (number_of_masks) x H x W\n",
    "# Full-resolution logit maps side by side: (N, H, W) -> (H, N * W).\n",
    "tiled_mask_logits = mask_logits.transpose(1, 0, 2).reshape(image.shape[0], -1)\n",
    "plt.figure(figsize=(10, 10))\n",
    "plt.imshow(tiled_mask_logits)\n",
    "plt.colorbar(fraction=0.025)\n",
    "plt.axis('on')\n",
    "plt.show()"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfca709f",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(figsize=(10, 10))\n",
    "ax.imshow(image)\n",
    "show_mask(masks, ax)\n",
    "show_points(input_point, input_label, ax)\n",
    "ax.axis('off')\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
